import dalex as dx
import pandas as pd
import numpy as np
import pickle
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')
import plotly
gbm = pickle.load(open("gbm.pickle", 'rb'))
svm = pickle.load(open("svm.pickle", 'rb'))
rfc = pickle.load(open("rfc.pickle", 'rb'))
gbc = pickle.load(open("gbc.pickle", 'rb'))
wines_df = pd.read_csv("winequality-red.csv")
wines_df["is_good"] = wines_df.apply(lambda row: 1 if row.quality > 5 else 0, axis = 1)
X = wines_df.drop(["quality", "is_good"], axis = 1)
y = wines_df[["is_good"]]
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify = y, random_state = 42)
explainer_gbm= dx.Explainer(gbm, X_train, y_train, verbose = 0, label="XGBoost")
explainer_svm = dx.Explainer(svm, X_train, y_train, verbose = 0, label="SVM")
explainer_rfc = dx.Explainer(rfc, X_train, y_train, verbose = 0, label="RFC")
explainer_gbc = dx.Explainer(gbc, X_train, y_train, verbose = 0, label="GBC")
pdp_gbm = explainer_gbm.model_profile()
pdp_svm = explainer_svm.model_profile()
pdp_rfc = explainer_rfc.model_profile()
pdp_gbc = explainer_gbc.model_profile()
Calculating ceteris paribus: 100%|█████████████| 11/11 [00:00<00:00, 11.39it/s] Calculating ceteris paribus: 100%|█████████████| 11/11 [00:06<00:00, 1.76it/s] Calculating ceteris paribus: 100%|█████████████| 11/11 [00:54<00:00, 4.94s/it] Calculating ceteris paribus: 100%|█████████████| 11/11 [00:01<00:00, 10.23it/s]
pdp_gbm.plot([pdp_svm, pdp_rfc, pdp_gbc])
free sulfur dioxide, gdzie inne modele praktycznie ignorują jej zmianę, a SVM wariujetotal sulfur dioxide krzywa SVM bardzo wzrasta dla skrajnych przypadków - jeszcze z EDA wiemy, że tylko pojedyncze obserwacje przyjmują wartości >200. Czy to znaczy, że SVM jest bardziej podatny na zwodnicze działanie outlierów? alcohol ponownie okazuje się mieć największy globalny wpływ na predykcję modeluale_gbm = explainer_gbm.model_profile(type = 'accumulated')
ale_svm = explainer_svm.model_profile(type = 'accumulated')
ale_rfc = explainer_rfc.model_profile(type = 'accumulated')
ale_gbc = explainer_gbc.model_profile(type = 'accumulated')
Calculating ceteris paribus: 100%|█████████████| 11/11 [00:00<00:00, 11.20it/s] Calculating accumulated dependency: 100%|██████| 11/11 [00:07<00:00, 1.45it/s] Calculating ceteris paribus: 100%|█████████████| 11/11 [00:06<00:00, 1.62it/s] Calculating accumulated dependency: 100%|██████| 11/11 [00:01<00:00, 6.84it/s] Calculating ceteris paribus: 100%|█████████████| 11/11 [00:56<00:00, 5.18s/it] Calculating accumulated dependency: 100%|██████| 11/11 [00:01<00:00, 7.47it/s] Calculating ceteris paribus: 100%|█████████████| 11/11 [00:00<00:00, 11.96it/s] Calculating accumulated dependency: 100%|██████| 11/11 [00:01<00:00, 7.35it/s]
ale_gbm.plot([ale_svm, ale_rfc, ale_gbc])
pH, ale nie są one szczególnie widoczne - pewnie dlatego, że ma ona i tak mały wpływ na predykcje. Nasza kluczowa zmienna alcohol jest z kolei mocno skorelowana jedynie z density, która też ma nikły wpływ na predykcję.plotly.offline.init_notebook_mode()